from tianshou.data import Batch
import numpy as np
from ActualCausal.Utils.weighting import get_weights
from ActualCausal.Updater.pseudo_null import infer_pseudo_null
from Network.network_utils import pytorch_model
import time
from Model.model_utils import set_net_parameters
from ACState.object_dict import ObjDict
from collections import deque

def update_net_parameters(i, model, args, params):
    """Apply ``params`` to ``model`` on every ``param_update_frequency``-th step.

    Reads ``args.train.param_update_frequency``. A negative frequency disables
    the update. Otherwise ``set_net_parameters`` runs whenever ``i`` is a
    multiple of the frequency.
    NOTE(review): a frequency of exactly 0 raises ZeroDivisionError on the
    modulo below.
    """
    frequency = args.train.param_update_frequency
    if frequency < 0 or i % frequency != 0:
        return
    set_net_parameters(model, args, params)

def update_running_log_prob(params, args, result, maxlen=1000):
    """Append mean log-probabilities from ``result`` to ``params.running_log_prob``.

    The deque (bounded by ``maxlen``) is created on the first call. Later calls
    keep the existing deque and its original bound. For every model form in
    ``result``, the entry for the first training name (or ``"all"`` when there
    are none) is inspected:

    - if the entry holds ``log_probs``, their mean is recorded;
    - otherwise, the mean is recorded for each ``Batch`` sub-entry that holds
      ``log_probs``.

    ``None`` entries are skipped.
    """
    if "running_log_prob" not in params: params.running_log_prob = deque(maxlen=maxlen)
    if result is None: return
    name = args.inter.train_names[0] if len(args.inter.train_names) > 0 else "all"
    for model_form in result.keys():
        entry = result[model_form][name]
        if entry is None:
            continue
        if "log_probs" in entry:
            # read from the same entry whose membership was just checked
            params.running_log_prob.append(pytorch_model.unwrap(entry.log_probs.mean()))
        else:
            for train_form in list(entry.keys()):
                sub_entry = entry[train_form]
                if type(sub_entry) is Batch and "log_probs" in sub_entry:
                    params.running_log_prob.append(pytorch_model.unwrap(sub_entry.log_probs.mean()))

def update_running_mask_logits(params, args, result, maxlen=1000):
    if "running_mask_logits" not in params: params.running_mask_logits = deque(maxlen=maxlen)
    if result is None: return
    name = args.inter.train_names[0] if len(args.inter.train_names) > 0 else "all"
    for model_form in result.keys():
        # print("mask_logits", list(result[model_form][name].keys()))
        if "mask_logits" in result[model_form][name]:
            if type(result[model_form][name].mask_logits) == np.ndarray: params.running_mask_logits.append(result[model_form][name].mask_logits.var(axis=0).max())
            else: params.running_mask_logits.append(pytorch_model.unwrap(result[model_form][name].mask_logits.var(dim=0).max()))
        else:
            train_forms = list(result[model_form][name].keys())
            for train_form in train_forms:
                if type(result[model_form][name][train_form]) is Batch and "mask_logits" in result[model_form][name][train_form]:
                    # print("searched", train_form, list(result[model_form][name][train_form].keys()))
                    params.running_mask_logits.append(pytorch_model.unwrap(result[model_form][name][train_form].mask_logits.var(dim=0).max()))

def adaptive_reset(adaptive_reset, params, base_reset, done_resetting, result):
    """Decide whether a scheduled reset fires, optionally gated on convergence.

    With adaptive gating enabled, a reset requires all of:

    - the periodic schedule requests one (``base_reset``);
    - the variance of the running log-probabilities is below ``adaptive_reset[0]``;
    - the variance of the running mask-logit statistics is below ``adaptive_reset[1]``;
    - the mean of those statistics is below ``adaptive_reset[2]``.

    Args:
        adaptive_reset: thresholds as described above. ``adaptive_reset[-1] <= 0``
            disables adaptive gating, in which case the periodic schedule
            (``base_reset``) is used as-is.
        params: object holding the ``running_log_prob`` and
            ``running_mask_logits`` sequences.
        base_reset: whether the periodic schedule requests a reset this step.
        done_resetting: truthy while resets are still allowed (``compute_params``
            passes ``i < limit``). No reset is issued when it is falsy.
        result: unused; kept for interface compatibility.

    Returns:
        bool: True when a reset should be triggered.
    """
    if not done_resetting:
        return False
    if adaptive_reset[-1] <= 0:
        # adaptive gating disabled: fall back to the periodic schedule rather
        # than suppressing every reset
        return bool(base_reset)
    log_probs, mask_logits = params.running_log_prob, params.running_mask_logits
    if len(log_probs) == 0 or len(mask_logits) == 0:
        # var/mean of an empty sequence is NaN (with a RuntimeWarning) and NaN
        # comparisons are False, so convergence cannot be established yet
        return False
    # TODO: make reset probabilistic
    log_prob_converged = np.var(log_probs) < adaptive_reset[0]
    mask_var_converged = np.var(mask_logits) < adaptive_reset[1]
    mask_mean_low = np.mean(mask_logits) < adaptive_reset[2]
    return bool(base_reset and log_prob_converged and mask_var_converged and mask_mean_low)


def _halving_decay(step, schedule):
    """Return ``schedule[0]`` halved every ``schedule[1]`` steps.

    Computes ``schedule[0] * 0.5 ** (step / schedule[1])`` when
    ``schedule[1] > 0``, otherwise returns the constant ``schedule[0]``.
    """
    if schedule[1] > 0:
        return schedule[0] * np.power(0.5, (step / schedule[1]))
    return schedule[0]


def _halving_growth(step, schedule):
    """Return a value that rises toward ``schedule[0]``.

    Computes ``schedule[0] * (1 - 0.5 ** (step / schedule[1]))`` when
    ``schedule[1] > 0``, otherwise returns the constant ``schedule[0]``.
    """
    if schedule[1] > 0:
        return schedule[0] * (1 - np.power(0.5, (step / schedule[1])))
    return schedule[0]


def _update_converged_active_loss(params, args, result):
    """Maintain ``params.converged_active_loss_value`` when the config defines one.

    With no ``result``, the value is (re)initialised from
    ``args.factor.converged_active_loss_value``. Otherwise, if it is already
    initialised, it is raised to the largest mean (over the batch) of the summed
    ``full`` log-probs seen for the first training name.
    """
    if "converged_active_loss_value" not in args.factor:
        return
    if result is None:
        params.converged_active_loss_value = args.factor.converged_active_loss_value
        return
    if not args.inter.train_names:
        return
    name = args.inter.train_names[0]
    for model_form in result.keys():
        entry = result[model_form][name]
        if entry is not None and "full" in entry and "converged_active_loss_value" in params:
            observed = pytorch_model.unwrap(entry.full.log_probs.sum(dim=-1).mean())
            params.converged_active_loss_value = max(params.converged_active_loss_value, observed)


def _clear_interaction_sample_weights(params):
    """Set every weight-binary-derived sampling weight (except the passive one) to None."""
    params.sample_interaction_weights = None
    params.sample_low_interaction_weights = None
    params.sample_active_weights = None
    params.sample_active_full_weights = None


def _refresh_sample_weights(params, buffer, weight_binaries, passive_weight_binaries):
    """Recompute sampling weights from the weight binaries and the scheduled lambdas.

    Binaries not passed explicitly are read from ``buffer`` when it has the
    attribute. The corresponding weights are set to None when there is no
    buffer or no binaries.
    """
    if passive_weight_binaries is None:
        passive_weight_binaries = getattr(buffer, "passive_weight_binary", None)
    if buffer is not None and passive_weight_binaries is not None:
        params.sample_passive_weights = get_weights(params.sample_passive_weight_lambda, passive_weight_binaries)
    else:
        params.sample_passive_weights = None

    if weight_binaries is None:
        weight_binaries = getattr(buffer, "weight_binary", None)
    if buffer is not None and weight_binaries is not None:
        params.sample_interaction_weights = get_weights(params.sample_inter_weight_lambda, weight_binaries)
        params.sample_low_interaction_weights = get_weights(params.sample_low_weight_lambda, weight_binaries)
        params.sample_active_weights = get_weights(params.sample_active_weight_lambda, weight_binaries)
        params.sample_active_full_weights = get_weights(params.sample_active_full_weight_lambda, weight_binaries)
        print("weighting lambda", params.sample_active_full_weight_lambda, params.sample_active_weight_lambda)
    else:
        # clear all four, including the low-interaction weights, so none go stale
        _clear_interaction_sample_weights(params)


def compute_params(i, args, buffer, pretrain=False, result=None, params = None, weight_binaries=None, passive_weight_binaries=None, model=None):
    """Compute the scheduled training hyperparameters for iteration ``i``.

    Args:
        i: current training iteration; drives the decay/growth schedules.
        args: experiment configuration. Reads the ``factor``, ``train``,
            ``inter``, ``infer``, ``active``, ``passive`` and ``masking``
            sections.
        buffer: replay buffer; the fallback source of ``weight_binary`` and
            ``passive_weight_binary`` when those are not passed explicitly.
        pretrain: selects ``args.infer.pretrain_mask_mode`` instead of
            ``args.infer.train_mask_mode``, and skips ``infer_pseudo_null``.
        result: latest training result, or None. Feeds the converged-loss
            tracker and the running statistics used for adaptive resets.
        params: ObjDict from a previous call, updated in place. A new one is
            created when None.
        weight_binaries, passive_weight_binaries: optional overrides for the
            buffer's weight binaries.
        model: passed to ``infer_pseudo_null`` when not pretraining.

    Returns:
        The updated ``params`` ObjDict.
    """
    if params is None:
        params = ObjDict()
    if "total_masking_steps" not in params:
        # also covers a caller-supplied params object that lacks the counter
        params.total_masking_steps = 0

    # TODO: this probably doesn't change, but if it does we should note that here
    _update_converged_active_loss(params, args, result)

    expectile = args.train.expectile
    if expectile.estimate_threshold and "converged_active_loss_value" in args.factor:
        # NOTE(review): offsets the *configured* converged value, not the running
        # estimate kept in params.converged_active_loss_value — confirm intended.
        params.expt_threshold = args.factor.converged_active_loss_value + expectile.expt_threshold
    else:
        params.expt_threshold = expectile.expt_threshold

    # regularization for embeddings and null TODO: put these on a schedule
    params.embed_reg = args.inter.regularization.embedding.embed_reg_lambda
    params.null_embed_reg = args.inter.regularization.null_consistency.null_reg_lambda
    # infer in training on a subset
    params.infer_num = args.infer.infer_num

    # active loss weighting: halves every `interaction_schedule` steps after the
    # train delay when > 1, 0.5 when negative, else the constant itself
    interaction_schedule = args.active.interaction_schedule
    if interaction_schedule > 1:
        params.active_full_weight = np.power(0.5, (max(0, i - args.active.delay_inter_train) / interaction_schedule))
    elif interaction_schedule < 0:
        params.active_full_weight = 0.5
    else:
        params.active_full_weight = interaction_schedule

    # number of interaction masking steps: grows as 2**(i/inline[2]) - 1,
    # capped at inline[0] and floored at inline[1]
    inline = args.masking.inline_iters
    if inline[2] > 1:
        grown_steps = min(inline[0], np.power(2, (i / inline[2])) - 1)
    else:
        grown_steps = inline[0]
    params.masking_steps = int(max(inline[1], grown_steps))
    # after this update, total_masking_steps - masking_steps is the count before this call
    params.total_masking_steps += params.masking_steps

    # sampling frequency parameters
    params.sample_passive_weight_lambda = _halving_decay(i, args.passive.weighting)
    params.sample_active_weight_lambda = _halving_decay(i, args.active.weighting)
    params.sample_active_full_weight_lambda = _halving_decay(i, args.active.full_weighting)
    params.sample_inter_weight_lambda = _halving_decay(i, args.masking.weighting)
    params.sample_low_weight_lambda = _halving_decay(i, args.masking.low_weighting)
    # interaction loss parameters
    params.entropy_lambda = float(_halving_decay(i, args.masking.entropy_weight))
    params.lasso_one_lambda = float(_halving_decay(i, args.masking.oneloss))
    params.lasso_half_lambda = float(_halving_decay(i, args.masking.halfloss))
    params.lasso = float(_halving_growth(i * 3.0, args.masking.lasso))
    params.adaptive_lasso = float(_halving_growth(i * 3.0, args.masking.adaptive_lasso))
    # other scheduled losses
    params.random_mask_rate = float(_halving_growth(i, args.active.random_masks.random_mask_schedule))
    params.soft_mask_param = float(_halving_growth(i, args.active.soft_masking))

    # reset logic: a periodic schedule, optionally gated by adaptive_reset
    resetting = args.active.resetting
    update_running_mask_logits(params, args, result, maxlen=resetting.running_maxlen)
    update_running_log_prob(params, args, result, maxlen=resetting.running_maxlen)
    inter_period, active_period = resetting.reset_inter, resetting.reset_active
    periodic_inter = ((i % inter_period[0]) == inter_period[1]) and i < inter_period[2]
    periodic_active = ((i % active_period[0]) == active_period[1]) and i < active_period[2]
    params.reset_inter = adaptive_reset(resetting.adaptive_inter_reset, params, periodic_inter, i < inter_period[2], result)
    params.reset_active = adaptive_reset(resetting.adaptive_active_reset, params, periodic_active, i < active_period[2], result)

    # updating the weights can be expensive, so only perform at frequency;
    # between updates the previously computed weights are kept
    update_frequency = args.train.param_update_frequency
    if update_frequency >= 0:
        if i % update_frequency == 0:
            _refresh_sample_weights(params, buffer, weight_binaries, passive_weight_binaries)
            if not pretrain and buffer is not None and model is not None:
                infer_pseudo_null(args, model, buffer, params)
    else:
        params.sample_passive_weights = None
        _clear_interaction_sample_weights(params)

    # variables that change between runs
    params.mask_mode = args.infer.pretrain_mask_mode if pretrain else args.infer.train_mask_mode

    return params